import numpy as np
import torch
import ntk_utils
import pandas as pd
from pylab import *
import argparse

# Command-line interface.  --test: 0 (default) uses the plain artifact file
# names; any non-zero value switches every path to its '_test' variant.
parser = argparse.ArgumentParser()
parser.add_argument('-a', '--arch', default='resnet', metavar='ARCH', help='model architecture')
parser.add_argument('--dataset', type=str, default='cifar10', help='which dataset used to train')
parser.add_argument('--exp', type=str, default='fgsm', help='exp name')
parser.add_argument('--test', type=int, default=0, help='if on test set')
args = parser.parse_args()


method = args.exp

# Epochs whose saved matrices are analysed: 1..9, then every 10th up to 200.
epoch_list = list(range(1, 10)) + list(range(10, 201, 10))

# ./<dataset>/<arch>/ prefix shared by every input file.  The templates below
# are formatted a second time later as ``template % (method, epoch)``, so any
# '%' coming from the CLI values is doubled here; otherwise it would be read
# as a format directive and raise or corrupt the path.
_base_dir = ('./%s/%s/' % (args.dataset, args.arch)).replace('%', '%%')
# Test-split files differ only by a '_test' suffix right before '.pt'.
_split_suffix = '' if args.test == 0 else '_test'

# Path templates, each filled in later with ``% (method, epoch)``.
matrix_ae_clean_path = _base_dir + '%s/matrix_ae_clean%d' + _split_suffix + '.pt'
matrix_ae_pgd_path = _base_dir + '%s/matrix_ae_pgd%d' + _split_suffix + '.pt'

clean_label_path = _base_dir + '%s/label_ae_clean%d_clean' + _split_suffix + '.pt'

clean_label_path_pgd = _base_dir + '%s/label_ae_pgd%d_clean' + _split_suffix + '.pt'
ae_label_path_pgd = _base_dir + '%s/label_ae_pgd%d_ae' + _split_suffix + '.pt'



def _load_class_matrix(path_template, epoch):
	"""Load one saved matrix and reduce it to its class-average form.

	Reads ``path_template % (method, epoch)`` with ``torch.load``, casts it to
	float32 via ``.float()`` and passes it to
	``ntk_utils.calculate_class_average_matrix``.  Every epoch goes through
	this single path, so the two matrices later compared by
	``cal_kernel_distance`` always share a dtype (previously the next-epoch
	matrix skipped the ``.float()`` cast).
	"""
	matrix = torch.load(path_template % (method, epoch)).float()
	return ntk_utils.calculate_class_average_matrix(matrix)


def _effective_rank(class_matrix):
	"""Return ``ntk_utils.erank`` of the strictly positive singular values of *class_matrix*."""
	# Only the singular values are needed, so skip computing U and V^H.
	singular_values = np.linalg.svd(class_matrix.numpy(), compute_uv=False)
	return ntk_utils.erank(singular_values[singular_values > 0])


# Effective rank of the class-average matrix at each epoch in epoch_list.
ae_clean_erank_list = []
ae_pgd_erank_list = []

# Kernel distance between each pair of consecutive epochs in epoch_list
# (len(epoch_list) - 1 entries, in epoch order, argument order (earlier, later)).
ae_clean_distance_list = []
ae_pgd_distance_list = []

# Single pass: each epoch's matrices are loaded once and kept only until the
# following epoch has been compared against them.
prev_class_clean = prev_class_pgd = None
for epoch in epoch_list:
	print(epoch)
	class_matrix_ae_clean = _load_class_matrix(matrix_ae_clean_path, epoch)
	class_matrix_ae_pgd = _load_class_matrix(matrix_ae_pgd_path, epoch)

	ae_clean_erank_list.append(_effective_rank(class_matrix_ae_clean))
	ae_pgd_erank_list.append(_effective_rank(class_matrix_ae_pgd))

	if prev_class_clean is not None:
		ae_clean_distance_list.append(
			ntk_utils.cal_kernel_distance(prev_class_clean.numpy(), class_matrix_ae_clean.numpy()))
		ae_pgd_distance_list.append(
			ntk_utils.cal_kernel_distance(prev_class_pgd.numpy(), class_matrix_ae_pgd.numpy()))

	prev_class_clean, prev_class_pgd = class_matrix_ae_clean, class_matrix_ae_pgd


# Per-epoch class-by-class ntk_utils.ksm tables, written as CSV files under
# ./<dataset>/<arch>/<exp>/.  File names get a '_test' suffix before '.csv'
# when --test is non-zero.
_ksm_out_dir = './%s/%s/' % (args.dataset, args.arch)
_ksm_split_suffix = '' if args.test == 0 else '_test'

for epoch in epoch_list:
	print(epoch)
	kernel_clean = torch.load(matrix_ae_clean_path % (method, epoch)).float().numpy()
	kernel_pgd = torch.load(matrix_ae_pgd_path % (method, epoch)).float().numpy()

	labels_clean = torch.load(clean_label_path % (method, epoch)).numpy().astype(np.int32)
	labels_pgd_clean = torch.load(clean_label_path_pgd % (method, epoch)).numpy().astype(np.int32)
	labels_pgd_ae = torch.load(ae_label_path_pgd % (method, epoch)).numpy().astype(np.int32)

	# The class count comes from the last axis of the clean matrix and sizes
	# all three tables.
	n_cls = kernel_clean.shape[-1]
	table_clean = np.zeros((n_cls, n_cls))
	table_pgd_clean_label = np.zeros((n_cls, n_cls))
	table_pgd_ae_label = np.zeros((n_cls, n_cls))

	# np.ndindex visits (row, col) in the same row-major order as two nested
	# range() loops.
	for row, col in np.ndindex(n_cls, n_cls):
		table_clean[row, col] = ntk_utils.ksm(kernel_clean, row, col, labels_clean)
		table_pgd_clean_label[row, col] = ntk_utils.ksm(kernel_pgd, row, col, labels_pgd_clean)
		table_pgd_ae_label[row, col] = ntk_utils.ksm(kernel_pgd, row, col, labels_pgd_ae)

	exp_dir = _ksm_out_dir + '%s/' % method
	for table, stem in (
		(table_clean, 'ksm_ae_clean_%d'),
		(table_pgd_clean_label, 'ksm_ae_pgd_%d_clean_label'),
		(table_pgd_ae_label, 'ksm_ae_pgd_%d_ae_label'),
	):
		frame = pd.DataFrame(table, index=np.arange(n_cls))
		frame.to_csv(exp_dir + stem % epoch + _ksm_split_suffix + '.csv')

print('ae_vs_clean erank:', ae_clean_erank_list)
print('ae_vs_pgd erank:', ae_pgd_erank_list)


print('\n')

print('ae_vs_clean Dis:', ae_clean_distance_list)
print('ae_vs_pgd Dis:', ae_pgd_distance_list)

print('\n')


def _summary_path(prefix, tag):
	"""Return the output path for one saved summary list.

	The path is ./<dataset>/<arch>/<exp>/<prefix><tag>.pt, with '_test'
	inserted before '.pt' when --test is non-zero.  All fields are
	substituted in a single formatting pass, so a '%' inside the dataset or
	arch name is copied literally instead of being interpreted as a format
	directive.
	"""
	suffix = '' if args.test == 0 else '_test'
	return './%s/%s/%s/%s%s%s.pt' % (args.dataset, args.arch, method, prefix, tag, suffix)


torch.save(ae_clean_erank_list, _summary_path('ae_clean', 'erank'))
torch.save(ae_pgd_erank_list, _summary_path('ae_pgd', 'erank'))

torch.save(ae_clean_distance_list, _summary_path('ae_clean', 'dist'))
torch.save(ae_pgd_distance_list, _summary_path('ae_pgd', 'dist'))

